#!/usr/bin/env bash
#
# Launch LoRA training (src/diffusion/train_text_to_image_lora.py) on cifar10
# through `accelerate launch`, once per entry in `methods`.
#
# Usage: bash <this-script> <exp_id> <num_shadow>
#   exp_id      Experiment identifier; becomes part of the checkpoint directory
#               name and is forwarded to the trainer as --exp_id.
#   num_shadow  Forwarded to the trainer as --num_shadow.
#
# Bash is required: the method list below is an array.

# ${N-} yields the same value as $N but does not abort if a caller has
# enabled `set -u` and the argument was omitted.
exp_id="${1-}"
num_shadow="${2-}"

# Directory holding cifar10 (passed to the trainer as --cache_dir).
DATA_PATH="datasets"
# Directory holding the extracted CLIP embeddings.
EMB_PATH="src/diffusion/clip_embeddings"
# Root directory under which checkpoints and logs are written.
LORA_PATH="src/diffusion/checkpoints"
# Directory for in/out split indices, noisy targets and canary indices.
LIRA_PATH="src/diffusion/lira"

# Base model identifier; exported so child processes can read it as well.
export MODEL_NAME="runwayml/stable-diffusion-v1-5"

# Methods to run. Only "gt_dm" has a launch command in this script.
methods=("gt_dm")
# Run training once per entry in `methods`, skipping any method whose final
# LoRA weight file is already present in its output directory.
#
# Reads:   exp_id, num_shadow, methods, MODEL_NAME, DATA_PATH, EMB_PATH,
#          LORA_PATH, LIRA_PATH (all assigned earlier in this script).
# Exports: OUTPUT_DIR and LOG_DIR, reassigned on every iteration.
# Exit:    after the loop, exits with status 1 if any method failed to train
#          or has no launch command; otherwise execution simply continues.

# Fail fast on missing inputs: an empty exp_id would produce an output
# directory named "<method>_/all" and empty --exp_id= / --num_shadow= flags.
# NOTE(review): presumably the trainer cannot use empty values for these
# flags anyway - confirm against train_text_to_image_lora.py.
: "${exp_id:?missing exp_id; usage: bash <script> <exp_id> <num_shadow>}"
: "${num_shadow:?missing num_shadow; usage: bash <script> <exp_id> <num_shadow>}"

failed_methods=()

for method in "${methods[@]}"; do
  export OUTPUT_DIR="${LORA_PATH}/${method}_${exp_id}/all"
  export LOG_DIR="${LORA_PATH}/logs"
  echo "Current OUTPUT_DIR: $OUTPUT_DIR, $exp_id"

  # The skip test looks for the weight file itself (.bin or .safetensors),
  # not merely for the directory, so a partially written run is retried.
  if [[ -f "$OUTPUT_DIR/pytorch_lora_weights.bin" ]] ||
    [[ -f "$OUTPUT_DIR/pytorch_lora_weights.safetensors" ]]; then
    echo "Folder exists. Skipping script execution."
    continue
  fi

  case "$method" in
    gt_dm)
      # Arguments live in an array so every value is passed as exactly one
      # word, even if a path or exp_id contains whitespace or glob characters.
      train_args=(
        --pretrained_model_name_or_path="$MODEL_NAME"
        --dataset_name=cifar10
        --report_to=tensorboard
        --resolution=512 --random_flip
        --train_batch_size=4
        --num_train_epochs=100 --checkpointing_steps=500
        --learning_rate=1e-04 --lr_scheduler=constant
        --seed=42
        --output_dir="$OUTPUT_DIR"
        --snr_gamma=5
        --guidance_token=8
        --dist_match=0.003
        --logging_dir "$LOG_DIR"
        --image_column=img
        --caption_column=label
        --cache_dir="$DATA_PATH"
        --embedding_path="$EMB_PATH"
        --lira_path="$LIRA_PATH"
        --exp_id="$exp_id"
        --num_shadow="$num_shadow"
        --num_canaries=512
      )

      # accelerate runs in the foreground, so no `wait` is needed; capture
      # its status instead of letting a failure pass silently.
      rc=0
      accelerate launch --mixed_precision=no \
        src/diffusion/train_text_to_image_lora.py "${train_args[@]}" || rc=$?
      if ((rc != 0)); then
        echo "Training failed for method '${method}' (exit status ${rc})" >&2
        failed_methods+=("$method")
        continue
      fi
      ;;
    *)
      echo "Method not implemented" >&2
      failed_methods+=("$method")
      continue
      ;;
  esac

  # Reached only when the launch above returned success.
  echo "All processes completed"
done

# Report failure to the caller only after every method has had its turn.
if ((${#failed_methods[@]} > 0)); then
  echo "Failed methods: ${failed_methods[*]}" >&2
  exit 1
fi